{
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "# Training and Managing MNIST Predictions in SuperDuperDB"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "4b24af19"
  },
  {
   "cell_type": "markdown",
   "id": "8905783f",
   "metadata": {},
   "source": [
    "## Introduction\n",
    "\n",
    "This notebook guides you through the implementation of a classic machine learning task: MNIST handwritten digit recognition. The twist? We perform the task directly in a database using SuperDuperDB.\n",
    "\n",
    "This example makes it easy to connect any of your image recognition model directly to your database in real-time. With SuperDuperDB, you can skip complicated MLOps pipelines. It's a new straightforward way to integrate your AI model with your data, ensuring simplicity, efficiency and speed. "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "95f897a45b2a02cc",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "## Prerequisites\n",
    "\n",
    "Before diving into the implementation, ensure that you have the necessary libraries installed by running the following commands:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a9897997-dee8-4947-9327-b96fe06a5a2c",
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install superduperdb\n",
    "!pip install torch torchvision matplotlib numpy==1.24.4"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e3812091",
   "metadata": {},
   "source": [
    "## Connect to datastore \n",
    "\n",
    "First, we need to establish a connection to a MongoDB datastore via SuperDuperDB. You can configure the `MongoDB_URI` based on your specific setup. \n",
    "\n",
    "Here are some examples of MongoDB URIs:\n",
    "\n",
    "* For testing (default connection): `mongomock://test`\n",
    "* Local MongoDB instance: `mongodb://localhost:27017`\n",
    "* MongoDB with authentication: `mongodb://superduper:superduper@mongodb:27017/documents`\n",
    "* MongoDB Atlas: `mongodb+srv://<username>:<password>@<atlas_cluster>/<database>`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a28adbce",
   "metadata": {},
   "outputs": [],
   "source": [
    "from superduperdb import superduper\n",
    "from superduperdb.backends.mongodb import Collection\n",
    "import os\n",
    "\n",
    "mongodb_uri = os.getenv(\"MONGODB_URI\", \"mongomock://test\")\n",
    "\n",
    "# Wrap the MongoDB database with SuperDuperDB;\n",
    "# model artifacts are stored on the local filesystem under ./data/\n",
    "db = superduper(mongodb_uri, artifact_store='filesystem://./data/')\n",
    "\n",
    "# Create a collection for MNIST\n",
    "mnist_collection = Collection('mnist')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6233e891",
   "metadata": {},
   "source": [
    "## Load Dataset\n",
    "\n",
    "After establishing a connection to MongoDB, the next step is to load the MNIST dataset. SuperDuperDB's strength lies in handling diverse data types, especially those that are challenging. To achieve this, we use an `Encoder` in conjunction with `Document` wrappers. These components allow Python dictionaries containing non-JSONable or bytes objects to be seamlessly inserted into the underlying data infrastructure."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bf0934cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torchvision\n",
    "from superduperdb.ext.pillow import pil_image\n",
    "from superduperdb import Document\n",
    "from superduperdb.backends.mongodb import Collection\n",
    "\n",
    "import random\n",
    "\n",
    "# Load MNIST images as Python objects using the Python Imaging Library.\n",
    "# Each MNIST item is a tuple (image, label)\n",
    "mnist_data = list(torchvision.datasets.MNIST(root='./data', download=True))\n",
    "\n",
    "# Create a list of Document instances from the MNIST data\n",
    "# Each Document has an 'img' field (encoded using the Pillow library) and a 'class' field\n",
    "document_list = [Document({'img': pil_image(x[0]), 'class': x[1]}) for x in mnist_data]\n",
    "\n",
    "# Shuffle the data and select a subset of 1000 documents\n",
    "random.shuffle(document_list)\n",
    "data = document_list[:1000]\n",
    "\n",
    "# Insert the selected data into the mnist_collection which we mentioned before like: mnist_collection = Collection('mnist')\n",
    "db.execute(\n",
    "    mnist_collection.insert_many(data[:-100]),  # Insert all but the last 100 documents\n",
    "    encoders=(pil_image,) # Encode images using the Pillow library.\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c5341135",
   "metadata": {},
   "source": [
    "Now that the images and their classes are inserted into the database, we can query the data in its original format. Particularly, we can use the `PIL.Image` instances to inspect the data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a36f9c3b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get and display one of the images\n",
    "r = db.execute(mnist_collection.find_one())\n",
    "r.unpack()['img']"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1413d4c5",
   "metadata": {},
   "source": [
    "## Build Model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "68fde8bb",
   "metadata": {},
   "source": [
    "Following that, we build our machine learning model. SuperDuperDB conveniently supports various frameworks, and for this example, we opt for PyTorch, a suitable choice for computer vision tasks. In this instance, we combine `torch` with `torchvision`.\n",
    "\n",
    "To facilitate communication with the SuperDuperDB `Datalayer`, we design `postprocess` and `preprocess` functions. These functions are then encapsulated with the model, preprocessing, and postprocessing steps to create a native SuperDuperDB handler."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cfb425e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Define the LeNet-5 architecture for image classification\n",
    "class LeNet5(torch.nn.Module):\n",
    "    def __init__(self, num_classes):\n",
    "        super().__init__()\n",
    "        # Layer 1\n",
    "        self.layer1 = torch.nn.Sequential(\n",
    "            torch.nn.Conv2d(1, 6, kernel_size=5, stride=1, padding=0),\n",
    "            torch.nn.BatchNorm2d(6),\n",
    "            torch.nn.ReLU(),\n",
    "            torch.nn.MaxPool2d(kernel_size=2, stride=2))\n",
    "        # Layer 2\n",
    "        self.layer2 = torch.nn.Sequential(\n",
    "            torch.nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0),\n",
    "            torch.nn.BatchNorm2d(16),\n",
    "            torch.nn.ReLU(),\n",
    "            torch.nn.MaxPool2d(kernel_size=2, stride=2))\n",
    "        # Fully connected layers\n",
    "        self.fc = torch.nn.Linear(400, 120)\n",
    "        self.relu = torch.nn.ReLU()\n",
    "        self.fc1 = torch.nn.Linear(120, 84)\n",
    "        self.relu1 = torch.nn.ReLU()\n",
    "        self.fc2 = torch.nn.Linear(84, num_classes)\n",
    "\n",
    "    def forward(self, x):\n",
    "        out = self.layer1(x)\n",
    "        out = self.layer2(out)\n",
    "        out = out.reshape(out.size(0), -1)\n",
    "        out = self.fc(out)\n",
    "        out = self.relu(out)\n",
    "        out = self.fc1(out)\n",
    "        out = self.relu1(out)\n",
    "        out = self.fc2(out)\n",
    "        return out\n",
    "\n",
    "# Postprocess function for the model output    \n",
    "def postprocess(x):\n",
    "    return int(x.topk(1)[1].item())\n",
    "\n",
    "# Preprocess function for input data\n",
    "def preprocess(x):\n",
    "    return torchvision.transforms.Compose([\n",
    "        torchvision.transforms.Resize((32, 32)),\n",
    "        torchvision.transforms.ToTensor(),\n",
    "        torchvision.transforms.Normalize(mean=(0.1307,), std=(0.3081,))]\n",
    "    )(x)\n",
    "\n",
    "# Create an instance of the LeNet-5 model\n",
    "lenet_model = LeNet5(10)\n",
    "\n",
    "# Create a SuperDuperDB model with the LeNet-5 model, preprocess, and postprocess functions\n",
    "# Specify 'preferred_devices' as ('cpu',) indicating CPU preference\n",
    "model = superduper(lenet_model, preprocess=preprocess, postprocess=postprocess, preferred_devices=('cpu',))\n",
    "db.add(model)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dcf0457e",
   "metadata": {},
   "source": [
    "## Train Model\n",
    "\n",
    "Now we are ready to \"train\" or \"fit\" the model. Trainable models in SuperDuperDB come with a sklearn-like `.fit` method. \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e7c610c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.nn.functional import cross_entropy\n",
    "\n",
    "from superduperdb import Metric\n",
    "from superduperdb import Dataset\n",
    "from superduperdb.ext.torch.model import TorchTrainerConfiguration\n",
    "\n",
    "# Fit the model to the training data\n",
    "job = model.fit(\n",
    "    X='img', # Document field used as the model input\n",
    "    y='class', # Document field used as the training target\n",
    "    db=db, # Database used for data retrieval\n",
    "    select=mnist_collection.find(), # Select the dataset from the 'mnist_collection'\n",
    "    configuration=TorchTrainerConfiguration(\n",
    "        identifier='my_configuration', # Unique identifier for the training configuration\n",
    "        objective=cross_entropy, # The objective function (cross-entropy in this case)\n",
    "        loader_kwargs={'batch_size': 10}, # DataLoader keyword arguments, batch size is set to 10\n",
    "        max_iterations=10, # Maximum number of training iterations\n",
    "        validation_interval=5, # Interval for validation during training\n",
    "    ),\n",
    "    metrics=[Metric(identifier='acc', object=lambda x, y: sum([xx == yy for xx, yy in zip(x, y)]) / len(x))], # Define a custom accuracy metric for evaluation during training\n",
    "    validation_sets=[\n",
    "        # Define a validation dataset using a subset of data with '_fold' equal to 'valid'\n",
    "        Dataset(\n",
    "            identifier='my_valid',\n",
    "            select=Collection('mnist').find({'_fold': 'valid'}),\n",
    "        )\n",
    "    ],\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fdf5cccb2fe0b97b",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "## Monitoring Training Efficiency\n",
    "You can monitor the training efficiency with visualization tools like Matplotlib:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "200d3be1",
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib import pyplot as plt\n",
    "\n",
    "# Load the model from the database\n",
    "model = db.load('model', model.identifier)\n",
    "\n",
    "# Plot the accuracy values\n",
    "plt.plot(model.metric_values['my_valid/acc'])\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0199b952",
   "metadata": {},
   "source": [
    "## On-the-fly Predictions\n",
    "\n",
    "After training the model, you can continuously predict on new data as it arrives. By activating a `listener` for the database, the model can make predictions on incoming data changes without having to load all the data client-side. The listen toggle triggers the model to predict based on updates in the incoming data.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f0e53249",
   "metadata": {},
   "outputs": [],
   "source": [
    "model.predict(\n",
    "    X='img', # Input feature  \n",
    "    db=db,  # Database used for data retrieval\n",
    "    select=mnist_collection.find(), # Select the dataset\n",
    "    listen=True, # Continuous predictions on incoming data \n",
    "    max_chunk_size=100, # Process the data in chunks of at most 100 documents\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7daae786",
   "metadata": {},
   "source": [
    "We can see that predictions are available in `_outputs.img.lenet5`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bc71a143",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Execute find_one() to retrieve a single document from the 'mnist_collection'. \n",
    "r = db.execute(mnist_collection.find_one({'_fold': 'valid'}))\n",
    "\n",
    "# Unpack the document and extract its content\n",
    "r.unpack()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7a78a2a1",
   "metadata": {},
   "source": [
    "## Verification\n",
    "\n",
    "The models \"activated\" can be seen here:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2a5308f4a158c931",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the status of the listener\n",
    "db.show('listener')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dee36a804224cbb6",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "We can verify that the model is activated, by inserting the rest of the data:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1aa56d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Iterate over the last 100 elements in the 'data' list\n",
    "for r in data[-100:]:\n",
    "    # Update the 'update' field to True for each document\n",
    "    r['update'] = True\n",
    "\n",
    "# Insert the updated documents (with 'update' set to True) into the 'mnist_collection'\n",
    "db.execute(mnist_collection.insert_many(data[-100:]))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9eb48a30",
   "metadata": {},
   "source": [
    "You can see that the inserted data, are now also populated with predictions:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d8161983",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Execute find_one() to retrieve a single sample document from 'mnist_collection'\n",
    "# where the 'update' field is True\n",
    "sample_document = db.execute(mnist_collection.find_one({'update': True}))['_outputs']\n",
    "\n",
    "# A sample document\n",
    "print(sample_document)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
